import csv
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

from evaluation.matrics_calculator import MetricsCalculator
from utils_loc import load_config


def mask_decode(encoded_mask, image_shape=(512, 512)):
    """Decode a run-length-encoded mask into a 2-D float array.

    ``encoded_mask`` is a flat sequence of ``(start, run_length)`` pairs that
    index the row-major flattened image. Every run is set to 1 and clipped at
    the end of the image. Afterwards the outermost row/column on every side is
    forced to 1 to absorb annotation errors along the image boundary.

    Args:
        encoded_mask: Flat sequence ``[start0, len0, start1, len1, ...]``.
        image_shape: ``(height, width)`` of the decoded mask. Defaults to
            ``(512, 512)``; any two-element sequence is accepted.

    Returns:
        ``np.ndarray`` of shape ``(height, width)``, dtype float64, holding
        only 0.0 / 1.0.

    Raises:
        ValueError: if ``encoded_mask`` has an odd number of entries.
    """
    height, width = image_shape[0], image_shape[1]
    length = height * width
    if len(encoded_mask) % 2:
        raise ValueError(
            f"encoded_mask must contain (start, length) pairs, got {len(encoded_mask)} values"
        )

    mask_array = np.zeros((length,))
    for i in range(0, len(encoded_mask), 2):
        start, run_length = encoded_mask[i], encoded_mask[i + 1]
        # Clip runs that would extend past the end of the image.
        splice_len = min(run_length, length - start)
        # One slice assignment instead of a per-pixel Python loop; an empty or
        # negative span leaves the array untouched, exactly like range() did.
        mask_array[start:start + splice_len] = 1

    mask_array = mask_array.reshape(height, width)
    # to avoid annotation errors in boundary
    mask_array[0, :] = 1
    mask_array[-1, :] = 1
    mask_array[:, 0] = 1
    mask_array[:, -1] = 1

    return mask_array


def calculate_metric(metrics_calculator, metric, src_image, tgt_image, src_mask, tgt_mask, src_prompt, tgt_prompt):
    """Compute one named metric between a source and a target image.

    Supported metric names:

    * ``<base>`` -- whole image, no mask;
    * ``<base>_edit_part`` -- restricted to ``src_mask`` / ``tgt_mask``;
    * ``<base>_unedit_part`` -- restricted to ``1 - src_mask`` / ``1 - tgt_mask``;

      where ``<base>`` is one of ``psnr``, ``lpips``, ``mse``, ``ssim``,
      ``structure_distance`` and dispatches to
      ``metrics_calculator.calculate_<base>(src, tgt, src_region, tgt_region)``;
    * ``clip_similarity_source_image`` -- ``src_image`` vs ``src_prompt``;
    * ``clip_similarity_target_image`` -- ``tgt_image`` vs ``tgt_prompt``;
    * ``clip_similarity_target_image_edit_part`` -- as above, within ``tgt_mask``.

    Returns:
        Whatever the calculator method returns, or the string ``"nan"`` when
        the requested masked region is empty (sums to 0). The caller converts
        that sentinel to NaN when averaging.

    Raises:
        ValueError: for an unrecognised metric name (previously this silently
            returned ``None`` and produced an empty CSV column).
    """
    if metric == "clip_similarity_source_image":
        return metrics_calculator.calculate_clip_similarity(src_image, src_prompt, None)
    if metric == "clip_similarity_target_image":
        return metrics_calculator.calculate_clip_similarity(tgt_image, tgt_prompt, None)
    if metric == "clip_similarity_target_image_edit_part":
        if tgt_mask.sum() == 0:
            return "nan"
        return metrics_calculator.calculate_clip_similarity(tgt_image, tgt_prompt, tgt_mask)

    pairwise_metrics = ("psnr", "lpips", "mse", "ssim", "structure_distance")

    # Split "<base><suffix>" into the base metric and the image region to use.
    base, region = metric, None
    for suffix in ("_unedit_part", "_edit_part"):
        if metric.endswith(suffix):
            base, region = metric[:-len(suffix)], suffix
            break

    if base not in pairwise_metrics:
        raise ValueError(f"Unknown metric: {metric!r}")

    # Looked up lazily so only the method actually needed must exist.
    calculate = getattr(metrics_calculator, f"calculate_{base}")

    if region is None:
        return calculate(src_image, tgt_image, None, None)

    if region == "_edit_part":
        src_region, tgt_region = src_mask, tgt_mask
    else:
        src_region, tgt_region = 1 - src_mask, 1 - tgt_mask

    # An empty region cannot be scored; emit the "nan" sentinel instead.
    if src_region.sum() == 0 or tgt_region.sum() == 0:
        return "nan"
    return calculate(src_image, tgt_image, src_region, tgt_region)


if __name__ == "__main__":
    # Evaluate every annotated image whose editing type is selected in the
    # config: compare the source image with the edited (target) image produced
    # by `{inversion method}-{editing method}`, write one CSV row per image and
    # finish with a row of per-column means.
    config = load_config()

    annotation_mapping_file = config.input_path
    metrics = config.metrics
    src_image_folder = config.dataset_path
    tgt_methods = config.method
    inverse_method = config.inversion.method
    edit_category_list = config.edit_category_list
    work_dir = Path(config.work_dir)
    result_path = work_dir.joinpath(f"evaluation_result_{inverse_method}-{tgt_methods}.csv")
    tgt_image_folder = work_dir.joinpath(f"{inverse_method}-{tgt_methods}")
    metrics_calculator = MetricsCalculator(config.device)

    # Load annotations before truncating the result file, so a bad input path
    # does not wipe an existing result CSV.
    with open(annotation_mapping_file, "r") as f:
        annotation_file = json.load(f)

    # Keep a single handle open for the whole run instead of re-opening the file
    # for every row; flushing after each row keeps partial results on disk if
    # the run is interrupted.
    with open(result_path, "w", newline="") as result_file:
        csv_write = csv.writer(result_file)
        csv_write.writerow(["file_id"] + [f"{tgt_methods}|{metric}" for metric in metrics])
        result_file.flush()

        for key, item in tqdm(annotation_file.items()):
            if item["editing_type_id"] not in edit_category_list:
                continue

            base_image_path = item["image_path"]
            # Square brackets are stripped from both prompts before scoring.
            original_prompt = item["original_prompt"].replace("[", "").replace("]", "")
            editing_prompt = item["editing_prompt"].replace("[", "").replace("]", "")

            # Decode the run-length mask (512x512 by default) and repeat it over
            # 3 channels so it has shape (H, W, 3).
            mask = mask_decode(item["mask"])
            mask = mask[:, :, np.newaxis].repeat([3], axis=2)

            src_image = Image.open(os.path.join(src_image_folder, base_image_path))
            tgt_image = Image.open(tgt_image_folder.joinpath(base_image_path))

            if tgt_image.size[0] != tgt_image.size[1]:
                # Non-square output: keep only the bottom-right 512x512 tile as
                # the edited image. NOTE(review): presumably the editing script
                # saves a side-by-side strip ending with the edit -- confirm.
                width, height = tgt_image.size
                tgt_image = tgt_image.crop((width - 512, height - 512, width, height))

            evaluation_result = [key]
            for metric in metrics:
                evaluation_result.append(
                    calculate_metric(metrics_calculator, metric, src_image, tgt_image, mask, mask,
                                     original_prompt, editing_prompt))

            csv_write.writerow(evaluation_result)
            result_file.flush()

    # Reload the finished CSV and append the per-column mean. Cells holding the
    # "nan" sentinel (empty mask region) become NaN and are skipped by the mean.
    df = pd.read_csv(result_path)
    metric_df = df.iloc[:, 1:].replace("nan", np.nan).astype(float)  # skip file_id
    mean_row = ["average"] + metric_df.mean(skipna=True).tolist()

    with open(result_path, "a", newline="") as f:
        csv.writer(f).writerow(mean_row)